"""
This file is only an EXAMPLE. It shows how to create a distribution of ClassificationTasks that load images from
individual files (here, from the Omniglot dataset) rather than from a single pickle file.

This is the recommended approach for any dataset larger than a few GB, because the whole dataset never has to be
held in a single in-memory object.
"""

import pickle
import os
import numpy as np
import tensorflow as tf
import cv2

from pyMeta.core.task import ClassificationTaskFromFiles
from pyMeta.core.task_distribution import TaskDistribution

# Module-level containers for the character-Omniglot train/test splits. Nothing else in this file reads or
# writes them; NOTE(review): presumably kept only for external code that imports them — confirm before removing.
charomniglot_trainX, charomniglot_trainY, charomniglot_testX, charomniglot_testY = [], [], [], []


def load_and_process_fn(filename, size=28):
    """
    Load an image file as a grayscale array resized to `size` x `size`, with a trailing channel axis.

    This is the `input_parse_fn` handed to the ClassificationTaskFromFiles objects in this file.

    Arguments:
    filename : str, bytes or os.PathLike
        Path to the image file. Non-str paths are converted with os.fsdecode, since cv2.imread needs a str.
        NOTE(review): bytes support is presumably useful if the task loader passes byte-string paths
        (e.g. from TensorFlow) — confirm against ClassificationTaskFromFiles.
    size : int
        Side length in pixels of the square output image. Defaults to 28, the previously hard-coded value.

    Returns:
    image : np.ndarray
        Array of shape (size, size, 1), with the dtype produced by cv2.imread/cv2.resize.

    Raises:
    FileNotFoundError
        If `filename` does not point to an existing regular file.
    OSError
        If OpenCV cannot decode the file.
    """
    path = os.fsdecode(filename)
    if not os.path.isfile(path):
        raise FileNotFoundError("Image file not found: {}".format(path))

    image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising; without this check the error
        # would surface later as an obscure assertion inside cv2.resize.
        raise OSError("Could not decode image file: {}".format(path))

    image = cv2.resize(image, (size, size))
    # Add a trailing channel axis: (size, size) -> (size, size, 1).
    return np.expand_dims(image, -1)

def _sorted_subdirectories(path):
    """Return the sorted names of the immediate subdirectories of `path`.

    Sorting makes label assignment reproducible (os.listdir returns entries in arbitrary order). Skipping plain
    files ignores stray entries such as '.DS_Store', which would otherwise make a nested os.listdir() fail.
    """
    return sorted(entry for entry in os.listdir(path) if os.path.isdir(os.path.join(path, entry)))


def _load_omniglot_metadataset(path):
    """Collect the .png file paths of every alphabet under `path`, with per-alphabet character labels.

    Returns two lists with one entry per alphabet (in sorted folder order):
    - a list of image file paths, and
    - an int64 array of labels aligned with those paths.
    Labels run over [0, number of characters with at least one .png). Characters without images are skipped
    without consuming a label, so labels within an alphabet are always contiguous.
    """
    filenames_per_alphabet = []
    labels_per_alphabet = []

    for alphabet in _sorted_subdirectories(path):
        alphabet_path = os.path.join(path, alphabet)
        filenames = []
        labels = []
        num_characters = 0

        for character in _sorted_subdirectories(alphabet_path):
            character_path = os.path.join(alphabet_path, character)
            samples = sorted(s for s in os.listdir(character_path)
                             if os.path.splitext(s)[1].lower() == '.png')
            if not samples:
                continue
            filenames.extend(os.path.join(character_path, s) for s in samples)
            labels.extend([num_characters] * len(samples))
            num_characters += 1

        filenames_per_alphabet.append(filenames)
        labels_per_alphabet.append(np.asarray(labels, dtype=np.int64))

    return filenames_per_alphabet, labels_per_alphabet


def _merge_alphabets(filenames_per_alphabet, labels_per_alphabet):
    """Concatenate per-alphabet datasets, offsetting labels so every character gets a unique global label."""
    merged_filenames = []
    merged_labels = []
    label_offset = 0

    for filenames, labels in zip(filenames_per_alphabet, labels_per_alphabet):
        merged_filenames.extend(filenames)
        merged_labels.append(labels + label_offset)
        if labels.size:
            # Labels are contiguous from 0 within an alphabet, so max + 1 is its number of characters.
            label_offset += int(labels.max()) + 1

    if not merged_labels:
        return merged_filenames, np.zeros(0, dtype=np.int64)
    return merged_filenames, np.concatenate(merged_labels).astype(np.int64)


def _load_omniglot_split(path_to_dataset, split_name):
    """Return (filenames, labels) for one Omniglot split folder, with globally unique character labels.

    Raises ValueError if the split contains no .png images.
    """
    split_path = os.path.join(path_to_dataset, split_name)
    filenames, labels = _merge_alphabets(*_load_omniglot_metadataset(split_path))
    if not filenames:
        raise ValueError("No .png images found under '{}'".format(split_path))
    return filenames, labels


def create_omniglot_from_files_task_distribution(path_to_dataset,
                                                 batch_size=32,
                                                 num_training_samples_per_class=10,
                                                 num_test_samples_per_class=-1,
                                                 num_training_classes=20,
                                                 meta_batch_size=5):
    """
    Returns TaskDistributions that, on each reset, sample a different set of Omniglot characters.

    Images are not read here. Each task receives the list of image file paths, with `load_and_process_fn`
    passed as its `input_parse_fn`. Every character of a split gets a unique integer label. Alphabets and
    characters are processed in sorted folder-name order, so labels are reproducible across runs and filesystems.

    Arguments:
    path_to_dataset : string
        Path to the Omniglot dataset. The folder must contain the two standard subfolders 'images_background'
        (training classes) and 'images_evaluation' (test classes). Each of these must contain one subfolder per
        alphabet. Each alphabet folder must contain one subfolder per character, holding all the .png
        repetitions of that character. Plain files at the alphabet/character level (e.g. '.DS_Store') are
        ignored, and characters without any .png image are skipped.
    batch_size : int
        Intended default size of the minibatches generated by the tasks. NOTE: this function currently does not
        forward it to the tasks.
    num_training_samples_per_class : int
        If -1, sample from the whole dataset. If >=1, the dataset re-samples num_training_samples_per_class
        examples for each class at each reset, and samples minibatches exclusively from them until the next
        reset. This is useful for, e.g., k-shot classification.
    num_test_samples_per_class : int
        Same as `num_training_samples_per_class`. Used to generate test sets for tasks on reset().
    num_training_classes : int
        If -1, use all the classes of the split. If >=1, the dataset re-samples `num_training_classes` classes
        at each reset, and samples minibatches exclusively from them until the next reset.
    meta_batch_size : int
        Default size of the meta-batches of tasks, used when meta-batches are sampled without specifying
        a meta-batch size.

    Returns:
    metatrain_task_distribution : TaskDistribution
        TaskDistribution object for use during training ('images_background').
    metaval_task_distribution : TaskDistribution
        TaskDistribution object for use during model validation. Currently this is the SAME object as
        metatest_task_distribution, because there is no separate validation split yet.
    metatest_task_distribution : TaskDistribution
        TaskDistribution object for use during testing ('images_evaluation').

    Raises:
    ValueError
        If either split contains no .png images.
    """
    # Pre-load all the filenames and their globally unique labels; the images themselves are loaded by the tasks.
    trX, trY = _load_omniglot_split(path_to_dataset, 'images_background')
    teX, teY = _load_omniglot_split(path_to_dataset, 'images_evaluation')

    # TODO: Possibly: merge lists, and re-split in different proportions? (e.g., current Omniglot 36-14 instead of 30-20)

    # Create ClassificationTask objects (one per split, each covering all of that split's characters)
    metatrain_tasks_list = [ClassificationTaskFromFiles(trX,
                                                        trY,
                                                        num_training_samples_per_class,
                                                        num_test_samples_per_class,
                                                        num_training_classes,
                                                        split_train_test=-1,  # defaults to num_train / (num_train+num_test)
                                                        input_parse_fn=load_and_process_fn)]
    metatest_tasks_list = [ClassificationTaskFromFiles(teX,
                                                       teY,
                                                       num_training_samples_per_class,
                                                       num_test_samples_per_class,
                                                       num_training_classes,
                                                       split_train_test=-1,
                                                       input_parse_fn=load_and_process_fn)]

    # Create TaskDistribution objects that wrap the ClassificationTask objects to produce meta-batches of tasks
    metatrain_task_distribution = TaskDistribution(tasks=metatrain_tasks_list,
                                                   task_probabilities=[1.0],
                                                   batch_size=meta_batch_size,
                                                   sample_with_replacement=True)

    metatest_task_distribution = TaskDistribution(tasks=metatest_tasks_list,
                                                  task_probabilities=[1.0],
                                                  batch_size=meta_batch_size,
                                                  sample_with_replacement=True)

    # TODO: split into validation and test! (validation currently reuses the test distribution)
    return metatrain_task_distribution, metatest_task_distribution, metatest_task_distribution
